import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import f1_score
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
pd.set_option('future.no_silent_downcasting', True)
pd.options.mode.chained_assignment = None 

# Load the few-shot DEBATE result tables (large -> df, base -> fr) and the
# Llama 3.1 8B prediction files (zero-shot -> llamazs, 25-shot -> llamafs).
df, fr, llamazs, llamafs = map(pd.read_csv, (
    './data/motn_fewshot_large.csv',
    './data/motn_fewshot_base.csv',
    './data/freedom_test.csv',
    './data/llama_motn_25shot.csv',
))

# Baseline F1 scores: the 'entailment' column is y_true and the model's
# prediction column is y_pred.
llamazs_f1, llamafs_f1 = (
    f1_score(frame['entailment'], frame[pred_col])
    for frame, pred_col in ((llamazs, 'llama'), (llamafs, 'llama_25shot'))
)

# Plot settings shared by both panels.
metric = 'f1'    # results-table column plotted on the y axis
textsize = 24    # base font size; tick labels use textsize - 4

# Training-sample counts (the 'n' column) shown on the x axis; one box each.
x_positions = [0, 10, 25, 50, 100]

# Two side-by-side panels sharing the y axis.
fig, axes = plt.subplots(1, 2, figsize=(16, 8), sharey=True)
custom_palette = sns.color_palette("colorblind", 2)


def _draw_panel(ax, results, color, title):
    """Draw one panel: a box per training-sample count plus black median dots.

    ``results`` must have an ``'n'`` column and a ``metric`` column; NaN metric
    values are dropped. Uses the module-level ``metric``, ``textsize`` and
    ``x_positions`` settings.

    Returns ``(bp, data, medians)``: the dict returned by ``Axes.boxplot``, the
    per-position value arrays, and their medians (NaN for an empty position).
    """
    data = [results[results['n'] == pos][metric].dropna().values for pos in x_positions]
    bp = ax.boxplot(data, positions=x_positions, widths=5, patch_artist=True, showfliers=True)

    # np.median warns on an empty array; use NaN directly (scatter skips NaN points).
    medians = [np.median(values) if len(values) else np.nan for values in data]
    ax.scatter(x_positions, medians, color='black', s=50, zorder=3)

    for box in bp['boxes']:
        box.set_facecolor(color)
    for median_line in bp['medians']:
        median_line.set(color='black', linewidth=2)

    ax.set_title(title, fontsize=textsize, fontweight='bold')
    # Boxes are 5 units wide, so pad past the outermost positions (0 and 100).
    ax.set_xlim(-5, 105)
    ax.set_xticks(x_positions)
    ax.set_xticklabels(x_positions, rotation=0)
    ax.tick_params(axis='both', labelsize=textsize - 4)
    ax.grid(False)
    return bp, data, medians


# Left panel: DEBATE Base (fr).
bp_fr, data_fr, medians_fr = _draw_panel(axes[0], fr, custom_palette[1], 'DEBATE Base')
axes[0].set_ylabel('F1', fontsize=textsize, fontweight='bold')
# The y axis is shared (sharey=True), so these ticks apply to both panels.
axes[0].set_yticks([.3, .4, .5, .6, .7, .8, .9])
axes[0].set_yticklabels(['30%', '40%', '50%', '60%', '70%', '80%', '90%'])

# Right panel: DEBATE Large (df).
bp_df, data_df, medians_df = _draw_panel(axes[1], df, custom_palette[0], 'DEBATE Large')

# Dashed reference lines for the two Llama baselines, drawn on both panels.
for ax in axes:
    ax.axhline(y=llamazs_f1, color='black', linestyle='--', linewidth=2, alpha=0.5)
    ax.axhline(y=llamafs_f1, color='black', linestyle='--', linewidth=2, alpha=0.5)

# Label the reference lines on the right panel, right-aligned just inside the x limit.
# The 25-shot label sits above its line and the 0-shot label below its own.
# NOTE(review): the labels can overlap if the two baseline F1s are very close.
label_x = axes[1].get_xlim()[1] - 2
axes[1].text(label_x, llamafs_f1 + .02, 'Llama 3.1 8B (25-Shot)', va='center', ha='right', size=textsize)
axes[1].text(label_x, llamazs_f1 - .02, 'Llama 3.1 8B (0-Shot)', va='center', ha='right', size=textsize)

# Reserve the bottom 3% of the figure for the shared x-axis label drawn below.
plt.tight_layout(rect=[0, 0.03, 1, 1])
# x=0.54 rather than 0.5 presumably centres the label under the panels given the
# left y-label offset -- confirm visually if the layout changes.
fig.text(0.54, 0.02, "Training Samples", ha='center', fontsize=textsize, fontweight='bold')

out_path = Path('./figures/figure_13.png')
out_path.parent.mkdir(parents=True, exist_ok=True)  # savefig does not create directories
plt.savefig(out_path, dpi=300)